"""《 Policy Gradient算法实现》
    时间：2024.12
    环境：CartPole-v1
    作者：不去幼儿园
"""
import argparse  # 导入命令行参数解析库
import gym  # 导入OpenAI Gym库，用于创建强化学习环境
import numpy as np  # 导入numpy库，用于处理数值计算
from itertools import count  # 导入count函数，用于生成整数序列

import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch的神经网络模块
import torch.nn.functional as F  # 导入PyTorch的神经网络功能函数模块
import torch.optim as optim  # 导入PyTorch的优化器模块
from torch.distributions import Categorical  # 导入PyTorch中类别分布的模块，用于离散动作选择

parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')  # 创建一个ArgumentParser对象，描述REINFORCE算法示例
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',  # 添加一个命令行参数：gamma，表示折扣因子（默认0.99）
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',  # 添加一个命令行参数：seed，用于设置随机种子（默认543）
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',  # 添加一个命令行参数：render，用于决定是否渲染环境
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',  # 添加一个命令行参数：log-interval，表示训练状态日志的输出间隔（默认10）
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()  # 解析命令行参数并将结果保存到args对象中

env = gym.make('CartPole-v1')  # 创建一个CartPole-v1环境，用于训练
torch.manual_seed(args.seed)  # 设置PyTorch的随机种子，以确保可复现性


class Policy(nn.Module):  # 定义一个名为Policy的类，继承自nn.Module，这是一个神经网络模型
    def __init__(self):  # 初始化方法
        super(Policy, self).__init__()  # 调用父类构造方法
        self.affine1 = nn.Linear(4, 128)  # 第一层全连接层，将输入维度4映射到128维
        self.affine2 = nn.Linear(128, 2)  # 第二层全连接层，将输入维度128映射到2维（表示动作空间的大小）

        self.saved_log_probs = []  # 用于保存每一时刻的动作概率的对数
        self.rewards = []  # 用于保存每一时刻的奖励

    def forward(self, x):  # 定义前向传播方法
        x = F.relu(self.affine1(x))  # 经过第一层后使用ReLU激活函数
        action_scores = self.affine2(x)  # 通过第二层得到每个动作的得分
        return F.softmax(action_scores, dim=1)  # 使用softmax函数计算动作的概率分布


policy = Policy()  # 创建Policy类的实例
optimizer = optim.Adam(policy.parameters(), lr=1e-2)  # 使用Adam优化器优化Policy模型，学习率为0.01
eps = np.finfo(np.float32).eps.item()  # 获取float32类型的最小正数，用于避免除零错误


def select_action(state):  # 定义选择动作的函数
    state = torch.from_numpy(state).float().unsqueeze(0)  # 将输入的状态从numpy数组转换为PyTorch张量，并增加一个维度
    probs = policy(state)  # 通过Policy网络计算每个动作的概率分布
    m = Categorical(probs)  # 用Categorical分布定义一个概率分布对象
    action = m.sample()  # 从该分布中采样一个动作
    policy.saved_log_probs.append(m.log_prob(action))  # 保存该动作的对数概率
    return action.item()  # 返回动作的值


def finish_episode():  # 定义结束一个回合的函数
    R = 0  # 初始化回报R为0
    policy_loss = []  # 初始化用于保存每个动作损失的列表
    rewards = []  # 初始化保存所有回报的列表
    for r in policy.rewards[::-1]:  # 从后往前遍历奖励列表
        R = r + args.gamma * R  # 计算当前时刻的回报（折扣奖励）
        rewards.insert(0, R)  # 将回报插入到列表的开头
    rewards = torch.tensor(rewards)  # 将奖励转换为PyTorch张量
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)  # 对奖励进行标准化
    for log_prob, reward in zip(policy.saved_log_probs, rewards):  # 遍历每个动作的对数概率和奖励
        policy_loss.append(-log_prob * reward)  # 计算每个动作的损失（负对数概率与标准化奖励的乘积）
    optimizer.zero_grad()  # 清除梯度
    policy_loss = torch.cat(policy_loss).sum()  # 计算所有动作的总损失
    policy_loss.backward()  # 反向传播计算梯度
    optimizer.step()  # 执行一步优化
    del policy.rewards[:]  # 清空保存的奖励列表
    del policy.saved_log_probs[:]  # 清空保存的对数概率列表


if __name__ == '__main__':  # 如果是直接运行该文件（而不是导入）
    running_reward = 10  # 初始化运行奖励
    for i_episode in count(1):  # 从1开始无限循环
        state, _ = env.reset()  # 重置环境并获取初始状态
        for t in range(10000):  # 限制每个回合的最大步数为10000
            action = select_action(state)  # 选择动作
            state, reward, done, _, _ = env.step(action)  # 执行动作并获取下一个状态、奖励等信息
            if args.render:  # 如果设置了渲染选项
                env.render()  # 渲染环境
            policy.rewards.append(reward)  # 保存奖励
            if done:  # 如果回合结束
                break  # 跳出循环

        running_reward = running_reward * 0.99 + t * 0.01  # 更新运行奖励（使用指数加权移动平均）
        finish_episode()  # 结束当前回合并进行学习
        if i_episode % args.log_interval == 0:  # 每log_interval步输出一次日志
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
                i_episode, t, running_reward))  # 打印当前回合、回合长度和平均长度
        if running_reward > env.spec.reward_threshold:  # 如果运行奖励超过环境的奖励阈值
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))  # 打印成功信息
            break  # 结束训练
